# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import pytest
import torch
import ttnn

from loguru import logger

from tests.sweep_framework.sweep_utils.roofline_utils import get_run_return
from tests.sweep_framework.sweep_utils.utils import gen_pytest_parametrize_args
from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time
from models.common.utility_functions import torch_random

# Per-test timeout for this sweep module; presumably read by the sweep framework
# runner (units assumed to be seconds — TODO confirm against the framework).
TIMEOUT: int = 70

# Each entry in "params" is the shape of the first tensor followed by the
# shape of the second tensor, concatenated into one tuple. The second tensor's
# shape starts at index len(entry) // 2. This is easiest to reason about when
# both tensors have the same rank, although some other combinations may be valid.
parameters = {
    "pytorch": {
        "params": [
            (1, 1024, 1024, 1024),
            (1, 1024, 1024, 3072),
            (1, 1024, 1024, 32128),
            (1, 1024, 1024, 4096),
            (1, 1024, 1024, 512),
            (1, 10, 10, 128),
            (1, 128, 128, 9216),
            (1, 2048, 2048, 512),
            (1, 3072, 3072, 768),
            (1, 384, 384, 512),
            (1, 4096, 4096, 1024),
            (1, 512, 512, 1024),
            (1, 512, 512, 2048),
            (1, 512, 512, 32128),
            (1, 512, 512, 384),
            (1, 512, 512, 50272),
            (1, 512, 512, 512),
            (1, 768, 768, 3072),
            (1, 768, 768, 32128),
            (1, 768, 768, 50257),
            (1, 768, 768, 512),
            (1, 768, 768, 51865),
            (1, 768, 768, 768),
            (10, 1024, 1024, 1024),
            (10, 1024, 1024, 4096),
            (10, 1, 1, 128),
            (10, 2048, 2048, 512),
            (10, 3072, 3072, 768),
            (10, 4096, 4096, 1024),
            (10, 512, 512, 2048),
            (10, 512, 512, 512),
            (10, 768, 768, 3072),
            (10, 768, 768, 768),
            (1024, 160, 160, 256),
            (1024, 384, 384, 192),
            (1024, 512, 512, 256),
            (1024, 640, 640, 640),
            (128, 1, 1, 9216),
            (15, 1024, 1024, 512),
            (15, 384, 384, 512),
            (15, 512, 512, 1024),
            (15, 512, 512, 384),
            (1500, 768, 768, 768),
            (16384, 32, 32, 256),
            (19, 1024, 1024, 256008),
            (196, 1024, 1024, 512),
            (196, 768, 768, 384),
            (197, 1024, 1024, 1024),
            (197, 768, 768, 768),
            (2, 512, 512, 1),
            (2, 512, 512, 512),
            (2048, 768, 768, 262),
            (225, 512, 512, 12),
            (225, 512, 512, 16),
            (225, 512, 512, 24),
            (225, 512, 512, 32),
            (225, 512, 512, 3),
            (225, 512, 512, 4),
            (225, 512, 512, 6),
            (225, 512, 512, 8),
            (256, 1024, 1024, 512),
            (256, 1280, 1280, 1280),
            (256, 256, 256, 256),
            (256, 768, 768, 384),
            (32, 1536, 1536, 250880),
            (4, 768, 768, 51865),
            (4, 768, 768, 768),
            (4096, 320, 320, 320),
            (4096, 64, 64, 256),
            (45, 768, 768, 50257),
            (45, 768, 768, 768),
            (49, 1536, 1536, 768),
            (49, 2048, 2048, 1024),
            (5, 1024, 1024, 1024),
            (5, 1024, 1024, 3072),
            (59, 1024, 1024, 512),
            (59, 512, 512, 1024),
            (59, 512, 512, 50272),
            (64, 1280, 1280, 1280),
            (64, 1536, 1536, 768),
            (64, 2048, 2048, 1024),
            (7, 18176, 18176, 4544),
            (7, 4544, 4544, 18176),
            (7, 4544, 4544, 4544),
            (7, 4544, 4544, 4672),
            (7, 4544, 4544, 65024),
            (7, 768, 768, 2),
            (768, 196, 196, 384),
            (784, 384, 384, 192),
            (784, 512, 512, 256),
            (9, 768, 768, 1280),
            (9, 768, 768, 320),
            (9, 768, 768, 640),
            (920, 256, 256, 256),
        ],
        "core_grid": [True, False],
        "dtype": [ttnn.float32, ttnn.bfloat16],
        "test_bias": [True, False],
    },
    "gpt": {
        "params": [
            (1, 1, 1, 1, 1, 1, 1, 1),
            (1, 1, 1, 1, 1, 1, 1, 2304),
            (1, 1, 1, 1, 1, 1, 1, 3072),
            (1, 1, 1, 1, 1, 1, 1, 65536),
            (1, 1, 1, 1, 1, 1, 1, 768),
            (1, 1, 1, 1, 1, 1, 1, 96),
            (1, 1, 1, 2304, 1, 1, 2304, 1),
            (1, 1, 1, 2304, 1, 1, 2304, 65536),
            (1, 1, 1, 2304, 1, 1, 2304, 768),
            (1, 1, 1, 3072, 1, 1, 3072, 1),
            (1, 1, 1, 3072, 1, 1, 3072, 65536),
            (1, 1, 1, 3072, 1, 1, 3072, 768),
            (1, 1, 1, 65536, 1, 1, 65536, 2304),
            (1, 1, 1, 65536, 1, 1, 65536, 3072),
            (1, 1, 1, 65536, 1, 1, 65536, 768),
            (1, 1, 1, 65536, 1, 1, 65536, 96),
            (1, 1, 1, 768, 1, 1, 768, 1),
            (1, 1, 1, 768, 1, 1, 768, 1024),
            (1, 1, 1, 768, 1, 1, 768, 2304),
            (1, 1, 1, 768, 1, 1, 768, 3072),
            (1, 1, 1, 768, 1, 1, 768, 65536),
            (1, 1, 1, 768, 1, 1, 768, 768),
            (1, 1, 1, 768, 1, 1, 768, 96),
            (1, 1, 1, 96, 1, 1, 96, 1),
            (1, 1, 1, 96, 1, 1, 96, 65536),
            (1, 1, 1, 96, 1, 1, 96, 768),
            (1, 1, 1024, 768, 1, 1, 768, 1),
            (1, 1, 1024, 768, 1, 1, 768, 1024),
            (1, 1, 1024, 768, 1, 1, 768, 2304),
            (1, 1, 1024, 768, 1, 1, 768, 3072),
            (1, 1, 1024, 768, 1, 1, 768, 65536),
            (1, 1, 1024, 768, 1, 1, 768, 768),
            (1, 1, 1024, 768, 1, 1, 768, 96),
            (1, 1, 2304, 1, 1, 1, 1, 1),
            (1, 1, 2304, 1, 1, 1, 1, 2304),
            (1, 1, 2304, 1, 1, 1, 1, 3072),
            (1, 1, 2304, 1, 1, 1, 1, 65536),
            (1, 1, 2304, 1, 1, 1, 1, 768),
            (1, 1, 2304, 1, 1, 1, 1, 96),
            (1, 1, 2304, 65536, 1, 1, 65536, 1),
            (1, 1, 2304, 65536, 1, 1, 65536, 2304),
            (1, 1, 2304, 65536, 1, 1, 65536, 3072),
            (1, 1, 2304, 65536, 1, 1, 65536, 768),
            (1, 1, 2304, 65536, 1, 1, 65536, 96),
            (1, 1, 2304, 768, 1, 1, 768, 1),
            (1, 1, 2304, 768, 1, 1, 768, 1024),
            (1, 1, 2304, 768, 1, 1, 768, 2304),
            (1, 1, 2304, 768, 1, 1, 768, 3072),
            (1, 1, 2304, 768, 1, 1, 768, 65536),
            (1, 1, 2304, 768, 1, 1, 768, 768),
            (1, 1, 2304, 768, 1, 1, 768, 96),
            (1, 1, 3072, 1, 1, 1, 1, 1),
            (1, 1, 3072, 1, 1, 1, 1, 2304),
            (1, 1, 3072, 1, 1, 1, 1, 3072),
            (1, 1, 3072, 1, 1, 1, 1, 65536),
            (1, 1, 3072, 1, 1, 1, 1, 768),
            (1, 1, 3072, 1, 1, 1, 1, 96),
            (1, 1, 3072, 65536, 1, 1, 65536, 1),
            (1, 1, 3072, 65536, 1, 1, 65536, 2304),
            (1, 1, 3072, 65536, 1, 1, 65536, 3072),
            (1, 1, 3072, 65536, 1, 1, 65536, 768),
            (1, 1, 3072, 65536, 1, 1, 65536, 96),
            (1, 1, 3072, 768, 1, 1, 768, 1),
            (1, 1, 3072, 768, 1, 1, 768, 1024),
            (1, 1, 3072, 768, 1, 1, 768, 2304),
            (1, 1, 3072, 768, 1, 1, 768, 3072),
            (1, 1, 3072, 768, 1, 1, 768, 65536),
            (1, 1, 3072, 768, 1, 1, 768, 768),
            (1, 1, 3072, 768, 1, 1, 768, 96),
            (1, 1, 65536, 1, 1, 1, 1, 1),
            (1, 1, 65536, 1, 1, 1, 1, 2304),
            (1, 1, 65536, 1, 1, 1, 1, 3072),
            (1, 1, 65536, 1, 1, 1, 1, 65536),
            (1, 1, 65536, 1, 1, 1, 1, 768),
            (1, 1, 65536, 1, 1, 1, 1, 96),
            (1, 1, 65536, 2304, 1, 1, 2304, 1),
            (1, 1, 65536, 2304, 1, 1, 2304, 65536),
            (1, 1, 65536, 2304, 1, 1, 2304, 768),
            (1, 1, 65536, 3072, 1, 1, 3072, 1),
            (1, 1, 65536, 3072, 1, 1, 3072, 65536),
            (1, 1, 65536, 3072, 1, 1, 3072, 768),
            (1, 1, 65536, 768, 1, 1, 768, 1),
            (1, 1, 65536, 768, 1, 1, 768, 1024),
            (1, 1, 65536, 768, 1, 1, 768, 2304),
            (1, 1, 65536, 768, 1, 1, 768, 3072),
            (1, 1, 65536, 768, 1, 1, 768, 65536),
            (1, 1, 65536, 768, 1, 1, 768, 768),
            (1, 1, 65536, 768, 1, 1, 768, 96),
            (1, 1, 65536, 96, 1, 1, 96, 65536),
            (1, 1, 65536, 96, 1, 1, 96, 768),
            (1, 1, 768, 1, 1, 1, 1, 1),
            (1, 1, 768, 1, 1, 1, 1, 2304),
            (1, 1, 768, 1, 1, 1, 1, 3072),
            (1, 1, 768, 1, 1, 1, 1, 65536),
            (1, 1, 768, 1, 1, 1, 1, 768),
            (1, 1, 768, 1, 1, 1, 1, 96),
            (1, 1, 768, 1024, 1, 1, 1024, 768),
            (1, 1, 768, 2304, 1, 1, 2304, 1),
            (1, 1, 768, 2304, 1, 1, 2304, 65536),
            (1, 1, 768, 2304, 1, 1, 2304, 768),
            (1, 1, 768, 3072, 1, 1, 3072, 1),
            (1, 1, 768, 3072, 1, 1, 3072, 65536),
            (1, 1, 768, 3072, 1, 1, 3072, 768),
            (1, 1, 768, 65536, 1, 1, 65536, 1),
            (1, 1, 768, 65536, 1, 1, 65536, 2304),
            (1, 1, 768, 65536, 1, 1, 65536, 3072),
            (1, 1, 768, 65536, 1, 1, 65536, 768),
            (1, 1, 768, 65536, 1, 1, 65536, 96),
            (1, 1, 768, 768, 1, 1, 768, 1),
            (1, 1, 768, 768, 1, 1, 768, 1024),
            (1, 1, 768, 768, 1, 1, 768, 2304),
            (1, 1, 768, 768, 1, 1, 768, 3072),
            (1, 1, 768, 768, 1, 1, 768, 65536),
            (1, 1, 768, 768, 1, 1, 768, 768),
            (1, 1, 768, 768, 1, 1, 768, 96),
            (1, 1, 768, 96, 1, 1, 96, 1),
            (1, 1, 768, 96, 1, 1, 96, 65536),
            (1, 1, 768, 96, 1, 1, 96, 768),
            (1, 1, 96, 1, 1, 1, 1, 1),
            (1, 1, 96, 1, 1, 1, 1, 2304),
            (1, 1, 96, 1, 1, 1, 1, 3072),
            (1, 1, 96, 1, 1, 1, 1, 65536),
            (1, 1, 96, 1, 1, 1, 1, 768),
            (1, 1, 96, 1, 1, 1, 1, 96),
            (1, 1, 96, 65536, 1, 1, 65536, 1),
            (1, 1, 96, 65536, 1, 1, 65536, 2304),
            (1, 1, 96, 65536, 1, 1, 65536, 3072),
            (1, 1, 96, 65536, 1, 1, 65536, 768),
            (1, 1, 96, 65536, 1, 1, 65536, 96),
            (1, 1, 96, 768, 1, 1, 768, 1),
            (1, 1, 96, 768, 1, 1, 768, 1024),
            (1, 1, 96, 768, 1, 1, 768, 2304),
            (1, 1, 96, 768, 1, 1, 768, 3072),
            (1, 1, 96, 768, 1, 1, 768, 65536),
            (1, 1, 96, 768, 1, 1, 768, 768),
            (1, 1, 96, 768, 1, 1, 768, 96),
            (1, 64, 1024, 768, 1, 1, 768, 1),
            (1, 64, 1024, 768, 1, 1, 768, 2304),
            (1, 64, 1024, 768, 1, 1, 768, 3072),
            (1, 64, 1024, 768, 1, 1, 768, 65536),
            (1, 64, 1024, 768, 1, 1, 768, 768),
            (1, 64, 1024, 768, 1, 1, 768, 96),
            (1, 64, 768, 1024, 1, 1, 1024, 768),
            (1, 64, 768, 1024, 1, 64, 1024, 768),
            (64, 1, 1, 1024, 1, 1, 1024, 768),
            (64, 1, 1, 1024, 64, 1, 1024, 1),
            (64, 1, 1, 1024, 64, 1, 1024, 2304),
            (64, 1, 1, 1024, 64, 1, 1024, 3072),
            (64, 1, 1, 1024, 64, 1, 1024, 768),
            (64, 1, 1, 1024, 64, 1, 1024, 96),
            (64, 1, 1, 768, 1, 1, 768, 1),
            (64, 1, 1, 768, 1, 1, 768, 1024),
            (64, 1, 1, 768, 1, 1, 768, 2304),
            (64, 1, 1, 768, 1, 1, 768, 3072),
            (64, 1, 1, 768, 1, 1, 768, 65536),
            (64, 1, 1, 768, 1, 1, 768, 768),
            (64, 1, 1, 768, 1, 1, 768, 96),
            (64, 1, 1, 768, 64, 1, 768, 1),
            (64, 1, 1, 768, 64, 1, 768, 1024),
            (64, 1, 1024, 1, 1, 1, 1, 2304),
            (64, 1, 1024, 1, 1, 1, 1, 3072),
            (64, 1, 1024, 1, 1, 1, 1, 768),
            (64, 1, 1024, 1, 1, 1, 1, 96),
            (64, 1, 1024, 1, 64, 1, 1, 1024),
            (64, 1, 1024, 1, 64, 1, 1, 768),
            (64, 1, 1024, 2304, 1, 1, 2304, 65536),
            (64, 1, 1024, 2304, 1, 1, 2304, 768),
            (64, 1, 1024, 2304, 64, 1, 2304, 1024),
            (64, 1, 1024, 3072, 1, 1, 3072, 1),
            (64, 1, 1024, 3072, 1, 1, 3072, 65536),
            (64, 1, 1024, 3072, 1, 1, 3072, 768),
            (64, 1, 1024, 768, 1, 1, 768, 1),
            (64, 1, 1024, 768, 1, 1, 768, 1024),
            (64, 1, 1024, 768, 1, 1, 768, 2304),
            (64, 1, 1024, 768, 1, 1, 768, 3072),
            (64, 1, 1024, 768, 1, 1, 768, 65536),
            (64, 1, 1024, 768, 1, 1, 768, 768),
            (64, 1, 1024, 768, 1, 1, 768, 96),
            (64, 1, 1024, 768, 64, 1, 768, 1024),
            (64, 1, 1024, 96, 1, 1, 96, 65536),
            (64, 1, 1024, 96, 1, 1, 96, 768),
            (64, 1, 1024, 96, 64, 1, 96, 1024),
            (64, 1, 2304, 1024, 1, 1, 1024, 768),
            (64, 1, 2304, 1024, 64, 1, 1024, 1),
            (64, 1, 2304, 1024, 64, 1, 1024, 2304),
            (64, 1, 2304, 1024, 64, 1, 1024, 3072),
            (64, 1, 2304, 1024, 64, 1, 1024, 768),
            (64, 1, 2304, 1024, 64, 1, 1024, 96),
            (64, 1, 3072, 1024, 1, 1, 1024, 768),
            (64, 1, 3072, 1024, 64, 1, 1024, 1),
            (64, 1, 3072, 1024, 64, 1, 1024, 2304),
            (64, 1, 3072, 1024, 64, 1, 1024, 3072),
            (64, 1, 3072, 1024, 64, 1, 1024, 768),
            (64, 1, 3072, 1024, 64, 1, 1024, 96),
            (64, 1, 768, 1, 1, 1, 1, 2304),
            (64, 1, 768, 1, 1, 1, 1, 3072),
            (64, 1, 768, 1, 1, 1, 1, 768),
            (64, 1, 768, 1, 1, 1, 1, 96),
            (64, 1, 768, 1, 64, 1, 1, 768),
            (64, 1, 768, 1024, 1, 1, 1024, 768),
            (64, 1, 768, 1024, 64, 1, 1024, 1),
            (64, 1, 768, 1024, 64, 1, 1024, 2304),
            (64, 1, 768, 1024, 64, 1, 1024, 3072),
            (64, 1, 768, 1024, 64, 1, 1024, 768),
            (64, 1, 768, 1024, 64, 1, 1024, 96),
            (64, 1, 96, 1024, 1, 1, 1024, 768),
            (64, 1, 96, 1024, 64, 1, 1024, 1),
            (64, 1, 96, 1024, 64, 1, 1024, 2304),
            (64, 1, 96, 1024, 64, 1, 1024, 3072),
            (64, 1, 96, 1024, 64, 1, 1024, 768),
            (64, 1, 96, 1024, 64, 1, 1024, 96),
            (64, 12, 1, 1024, 1, 1, 1024, 768),
            (64, 12, 1, 1024, 64, 12, 1024, 1),
            (64, 12, 1, 1024, 64, 12, 1024, 1024),
            (64, 12, 1, 1024, 64, 12, 1024, 64),
            (64, 12, 1024, 1, 1, 1, 1, 1),
            (64, 12, 1024, 1, 1, 1, 1, 2304),
            (64, 12, 1024, 1, 1, 1, 1, 3072),
            (64, 12, 1024, 1, 1, 1, 1, 768),
            (64, 12, 1024, 1, 1, 1, 1, 96),
            (64, 12, 1024, 1, 64, 12, 1, 1024),
            (64, 12, 1024, 1024, 1, 1, 1024, 768),
            (64, 12, 1024, 1024, 64, 12, 1024, 1),
            (64, 12, 1024, 1024, 64, 12, 1024, 1024),
            (64, 12, 1024, 1024, 64, 12, 1024, 64),
            (64, 12, 1024, 64, 64, 12, 64, 1024),
            (64, 12, 64, 1024, 1, 1, 1024, 768),
            (64, 12, 64, 1024, 64, 12, 1024, 1),
            (64, 12, 64, 1024, 64, 12, 1024, 1024),
            (64, 12, 64, 1024, 64, 12, 1024, 64),
        ],
        "core_grid": [True, False],
        "dtype": [ttnn.float32, ttnn.bfloat16],
        "test_bias": [True, False],
    },
    "forge": {
        "params": [
            (
                1,
                1024,
                1024,
                1000,
            ),
            (
                1,
                1024,
                1024,
                1024,
            ),
            (
                1,
                1024,
                1024,
                32128,
            ),
            (
                1,
                1024,
                1024,
                4096,
            ),
            (
                1,
                1024,
                1024,
                512,
            ),
            (
                1,
                1024,
                160,
                1,
                160,
                256,
            ),
            (
                1,
                1024,
                640,
                1,
                640,
                160,
            ),
            (
                1,
                12,
                12,
                3,
            ),
            (
                1,
                12,
                12,
                64,
            ),
            (
                1,
                1200,
                1280,
                1,
                1280,
                320,
            ),
            (
                1,
                128,
                128,
                10,
            ),
            (
                1,
                128,
                128,
                64,
            ),
            (
                1,
                128,
                128,
                784,
            ),
            (
                1,
                1280,
                1280,
                1000,
            ),
            (
                1,
                1280,
                1280,
                1280,
            ),
            (
                1,
                1280,
                1280,
                320,
            ),
            (
                1,
                1280,
                1280,
                640,
            ),
            (
                1,
                1536,
                1536,
                3129,
            ),
            (
                1,
                16384,
                128,
                1,
                128,
                32,
            ),
            (
                1,
                16384,
                256,
                1,
                256,
                32,
            ),
            (
                1,
                16384,
                32,
                1,
                32,
                256,
            ),
            (
                1,
                19200,
                256,
                1,
                256,
                64,
            ),
            (
                1,
                19200,
                300,
                1,
                300,
                64,
            ),
            (
                1,
                19200,
                64,
                1,
                64,
                300,
            ),
            (
                1,
                2048,
                2048,
                1000,
            ),
            (
                1,
                2048,
                2048,
                512,
            ),
            (
                1,
                256,
                1024,
                1,
                1024,
                256,
            ),
            (
                1,
                256,
                256,
                1,
                256,
                256,
            ),
            (
                1,
                3,
                3,
                12,
            ),
            (
                1,
                300,
                2048,
                1,
                2048,
                512,
            ),
            (
                1,
                3072,
                3072,
                768,
            ),
            (
                1,
                32,
                1,
                1,
                1,
                7,
            ),
            (
                1,
                320,
                320,
                1280,
            ),
            (
                1,
                384,
                384,
                512,
            ),
            (
                1,
                4096,
                256,
                1,
                256,
                64,
            ),
            (
                1,
                4096,
                4096,
                1024,
            ),
            (
                1,
                4096,
                64,
                1,
                64,
                256,
            ),
            (
                1,
                4800,
                512,
                1,
                512,
                128,
            ),
            (
                1,
                512,
                512,
                1000,
            ),
            (
                1,
                512,
                512,
                1024,
            ),
            (
                1,
                512,
                512,
                2048,
            ),
            (
                1,
                512,
                512,
                32128,
            ),
            (
                1,
                512,
                512,
                384,
            ),
            (
                1,
                512,
                512,
                512,
            ),
            (
                1,
                64,
                1,
                1,
                1,
                32,
            ),
            (
                1,
                64,
                64,
                12,
            ),
            (
                1,
                64,
                64,
                128,
            ),
            (
                1,
                768,
                768,
                1,
            ),
            (
                1,
                768,
                768,
                1000,
            ),
            (
                1,
                768,
                768,
                1536,
            ),
            (
                1,
                768,
                768,
                2,
            ),
            (
                1,
                768,
                768,
                3,
            ),
            (
                1,
                768,
                768,
                3072,
            ),
            (
                1,
                768,
                768,
                32128,
            ),
            (
                1,
                768,
                768,
                512,
            ),
            (
                1,
                768,
                768,
                768,
            ),
            (
                1,
                784,
                784,
                128,
            ),
            (
                1,
                9216,
                9216,
                128,
            ),
            (
                10,
                1024,
                1024,
                1024,
            ),
            (
                10,
                1024,
                1024,
                4096,
            ),
            (
                10,
                2048,
                2048,
                512,
            ),
            (
                10,
                3072,
                3072,
                768,
            ),
            (
                10,
                4096,
                4096,
                1024,
            ),
            (
                10,
                512,
                512,
                2048,
            ),
            (
                10,
                512,
                512,
                512,
            ),
            (
                10,
                768,
                768,
                250002,
            ),
            (
                10,
                768,
                768,
                3072,
            ),
            (
                10,
                768,
                768,
                768,
            ),
            (
                100,
                192,
                192,
                4,
            ),
            (
                100,
                192,
                192,
                92,
            ),
            (
                100,
                2048,
                2048,
                256,
            ),
            (
                100,
                256,
                256,
                2048,
            ),
            (
                100,
                256,
                256,
                256,
            ),
            (
                1024,
                160,
                160,
                160,
            ),
            (
                1024,
                160,
                160,
                640,
            ),
            (
                1024,
                2560,
                2560,
                640,
            ),
            (
                1024,
                640,
                640,
                5120,
            ),
            (
                1024,
                640,
                640,
                640,
            ),
            (
                12,
                1,
                1,
                12,
                1,
                64,
            ),
            (
                12,
                1,
                10,
                12,
                10,
                64,
            ),
            (
                12,
                1,
                64,
                12,
                64,
                1,
            ),
            (
                12,
                1,
                64,
                12,
                64,
                10,
            ),
            (
                12,
                10,
                64,
                12,
                64,
                10,
            ),
            (
                12,
                12,
                12,
                12,
                12,
                64,
            ),
            (
                12,
                12,
                64,
                12,
                64,
                12,
            ),
            (
                12,
                128,
                128,
                768,
            ),
            (
                12,
                14,
                14,
                12,
                14,
                64,
            ),
            (
                12,
                14,
                64,
                12,
                64,
                14,
            ),
            (
                12,
                16,
                16,
                12,
                16,
                64,
            ),
            (
                12,
                16,
                64,
                12,
                64,
                16,
            ),
            (
                12,
                197,
                197,
                12,
                197,
                64,
            ),
            (
                12,
                197,
                64,
                12,
                64,
                197,
            ),
            (
                12,
                201,
                201,
                12,
                201,
                64,
            ),
            (
                12,
                201,
                64,
                12,
                64,
                201,
            ),
            (
                12,
                25,
                25,
                12,
                25,
                64,
            ),
            (
                12,
                25,
                64,
                12,
                64,
                25,
            ),
            (
                12,
                3072,
                3072,
                768,
            ),
            (
                12,
                50,
                50,
                12,
                50,
                64,
            ),
            (
                12,
                50,
                64,
                12,
                64,
                50,
            ),
            (
                12,
                7,
                64,
                12,
                64,
                7,
            ),
            (
                12,
                7,
                7,
                12,
                7,
                64,
            ),
            (
                12,
                768,
                768,
                2,
            ),
            (
                12,
                768,
                768,
                3072,
            ),
            (
                12,
                768,
                768,
                768,
            ),
            (
                12,
                8,
                64,
                12,
                64,
                8,
            ),
            (
                12,
                8,
                8,
                12,
                8,
                64,
            ),
            (
                12,
                9,
                64,
                12,
                64,
                9,
            ),
            (
                12,
                9,
                9,
                12,
                9,
                64,
            ),
            (
                1200,
                320,
                320,
                1280,
            ),
            (
                1200,
                320,
                320,
                320,
            ),
            (
                14,
                128,
                128,
                768,
            ),
            (
                14,
                2048,
                2048,
                512,
            ),
            (
                14,
                3072,
                3072,
                768,
            ),
            (
                14,
                512,
                512,
                2048,
            ),
            (
                14,
                512,
                512,
                512,
            ),
            (
                14,
                768,
                768,
                2,
            ),
            (
                14,
                768,
                768,
                3072,
            ),
            (
                14,
                768,
                768,
                768,
            ),
            (
                1445,
                192,
                192,
                192,
            ),
            (
                1445,
                192,
                192,
                768,
            ),
            (
                1445,
                768,
                768,
                192,
            ),
            (
                15,
                1024,
                1024,
                512,
            ),
            (
                15,
                384,
                384,
                512,
            ),
            (
                15,
                512,
                512,
                1024,
            ),
            (
                15,
                512,
                512,
                384,
            ),
            (
                16,
                1,
                1,
                16,
                1,
                64,
            ),
            (
                16,
                1,
                10,
                16,
                10,
                64,
            ),
            (
                16,
                1,
                64,
                16,
                64,
                1,
            ),
            (
                16,
                1,
                64,
                16,
                64,
                10,
            ),
            (
                16,
                10,
                10,
                16,
                10,
                64,
            ),
            (
                16,
                10,
                64,
                16,
                64,
                10,
            ),
            (
                16,
                19,
                19,
                16,
                19,
                64,
            ),
            (
                16,
                19,
                64,
                16,
                64,
                19,
            ),
            (
                16,
                197,
                197,
                16,
                197,
                64,
            ),
            (
                16,
                197,
                64,
                16,
                64,
                197,
            ),
            (
                16,
                256,
                256,
                16,
                256,
                64,
            ),
            (
                16,
                256,
                64,
                16,
                64,
                256,
            ),
            (
                16,
                3072,
                3072,
                768,
            ),
            (
                16,
                32,
                32,
                16,
                32,
                96,
            ),
            (
                16,
                5,
                5,
                16,
                5,
                64,
            ),
            (
                16,
                5,
                64,
                16,
                64,
                5,
            ),
            (
                16,
                6,
                6,
                16,
                6,
                64,
            ),
            (
                16,
                6,
                64,
                16,
                64,
                6,
            ),
            (
                16,
                7,
                64,
                16,
                64,
                7,
            ),
            (
                16,
                7,
                7,
                16,
                7,
                64,
            ),
            (
                16,
                768,
                768,
                3072,
            ),
            (
                16,
                768,
                768,
                768,
            ),
            (
                16,
                9,
                128,
                16,
                128,
                9,
            ),
            (
                16,
                9,
                64,
                16,
                64,
                9,
            ),
            (
                16,
                9,
                9,
                16,
                9,
                128,
            ),
            (
                16,
                9,
                9,
                16,
                9,
                64,
            ),
            (
                16384,
                32,
                32,
                128,
            ),
            (
                16384,
                32,
                32,
                32,
            ),
            (
                19,
                1024,
                1024,
                1024,
            ),
            (
                19,
                1024,
                1024,
                256008,
            ),
            (
                19,
                1024,
                1024,
                4096,
            ),
            (
                19,
                4096,
                4096,
                1024,
            ),
            (
                19200,
                64,
                64,
                256,
            ),
            (
                19200,
                64,
                64,
                64,
            ),
            (
                197,
                1024,
                1024,
                1024,
            ),
            (
                197,
                1024,
                1024,
                4096,
            ),
            (
                197,
                3072,
                3072,
                768,
            ),
            (
                197,
                4096,
                4096,
                1024,
            ),
            (
                197,
                768,
                768,
                3072,
            ),
            (
                197,
                768,
                768,
                768,
            ),
            (
                2,
                4096,
                256,
                2,
                256,
                32,
            ),
            (
                2,
                4096,
                32,
                2,
                32,
                256,
            ),
            (
                2,
                4800,
                300,
                2,
                300,
                64,
            ),
            (
                2,
                4800,
                64,
                2,
                64,
                300,
            ),
            (
                2,
                512,
                512,
                1,
            ),
            (
                2,
                512,
                512,
                512,
            ),
            (
                201,
                3072,
                3072,
                768,
            ),
            (
                201,
                768,
                768,
                3072,
            ),
            (
                201,
                768,
                768,
                768,
            ),
            (
                2048,
                768,
                768,
                1280,
            ),
            (
                2048,
                768,
                768,
                256,
            ),
            (
                2048,
                768,
                768,
                262,
            ),
            (
                2048,
                768,
                768,
                768,
            ),
            (
                25,
                3072,
                3072,
                768,
            ),
            (
                25,
                768,
                768,
                2,
            ),
            (
                25,
                768,
                768,
                3072,
            ),
            (
                25,
                768,
                768,
                768,
            ),
            (
                256,
                1024,
                1024,
                1024,
            ),
            (
                256,
                1024,
                1024,
                2,
            ),
            (
                256,
                1024,
                1024,
                4096,
            ),
            (
                256,
                1280,
                1280,
                10240,
            ),
            (
                256,
                1280,
                1280,
                1280,
            ),
            (
                256,
                1280,
                1280,
                256,
            ),
            (
                256,
                1280,
                1280,
                768,
            ),
            (
                256,
                160,
                160,
                160,
            ),
            (
                256,
                256,
                256,
                1024,
            ),
            (
                256,
                256,
                256,
                256,
            ),
            (
                256,
                256,
                256,
                512,
            ),
            (
                256,
                32,
                32,
                32,
            ),
            (
                256,
                4096,
                4096,
                1024,
            ),
            (
                256,
                512,
                512,
                256,
            ),
            (
                256,
                5120,
                5120,
                1280,
            ),
            (
                256,
                64,
                64,
                64,
            ),
            (
                256,
                768,
                768,
                512,
            ),
            (
                3,
                1445,
                1445,
                3,
                1445,
                64,
            ),
            (
                3,
                1445,
                64,
                3,
                64,
                1445,
            ),
            (
                300,
                128,
                128,
                128,
            ),
            (
                300,
                320,
                320,
                320,
            ),
            (
                300,
                512,
                512,
                2048,
            ),
            (
                300,
                512,
                512,
                512,
            ),
            (
                300,
                64,
                64,
                64,
            ),
            (
                32,
                11008,
                11008,
                4096,
            ),
            (
                32,
                1536,
                1536,
                1536,
            ),
            (
                32,
                1536,
                1536,
                250880,
            ),
            (
                32,
                1536,
                1536,
                4608,
            ),
            (
                32,
                1536,
                1536,
                6144,
            ),
            (
                32,
                32,
                128,
                32,
                128,
                32,
            ),
            (
                32,
                32,
                32,
                32,
                32,
                128,
            ),
            (
                32,
                4096,
                4096,
                11008,
            ),
            (
                32,
                4096,
                4096,
                32000,
            ),
            (
                32,
                4096,
                4096,
                4096,
            ),
            (
                32,
                6144,
                6144,
                1536,
            ),
            (
                4096,
                1280,
                1280,
                320,
            ),
            (
                4096,
                320,
                320,
                2560,
            ),
            (
                4096,
                320,
                320,
                320,
            ),
            (
                4096,
                64,
                64,
                256,
            ),
            (
                4096,
                64,
                64,
                64,
            ),
            (
                4800,
                128,
                128,
                128,
            ),
            (
                4800,
                128,
                128,
                512,
            ),
            (
                5,
                1024,
                1024,
                1024,
            ),
            (
                5,
                1024,
                1024,
                3072,
            ),
            (
                5,
                1024,
                1024,
                4096,
            ),
            (
                5,
                1024,
                1024,
                51200,
            ),
            (
                5,
                1024,
                256,
                5,
                256,
                32,
            ),
            (
                5,
                1024,
                32,
                5,
                32,
                256,
            ),
            (
                5,
                1200,
                300,
                5,
                300,
                64,
            ),
            (
                5,
                1200,
                64,
                5,
                64,
                300,
            ),
            (
                5,
                4096,
                4096,
                1024,
            ),
            (
                50,
                3072,
                3072,
                768,
            ),
            (
                50,
                768,
                768,
                3072,
            ),
            (
                50,
                768,
                768,
                768,
            ),
            (
                6,
                1,
                1,
                6,
                1,
                64,
            ),
            (
                6,
                1,
                15,
                6,
                15,
                64,
            ),
            (
                6,
                1,
                64,
                6,
                64,
                1,
            ),
            (
                6,
                1,
                64,
                6,
                64,
                15,
            ),
            (
                6,
                100,
                256,
                6,
                256,
                256,
            ),
            (
                6,
                100,
                256,
                6,
                256,
                92,
            ),
            (
                6,
                1024,
                1024,
                1024,
            ),
            (
                6,
                1024,
                1024,
                4096,
            ),
            (
                6,
                1024,
                1024,
                512,
            ),
            (
                6,
                15,
                15,
                6,
                15,
                64,
            ),
            (
                6,
                15,
                64,
                6,
                64,
                15,
            ),
            (
                6,
                4096,
                4096,
                1024,
            ),
            (
                6,
                512,
                512,
                1024,
            ),
            (
                6,
                512,
                512,
                50272,
            ),
            (
                600,
                256,
                256,
                256,
            ),
            (
                600,
                256,
                256,
                4,
            ),
            (
                64,
                1280,
                1280,
                10240,
            ),
            (
                64,
                1280,
                1280,
                1280,
            ),
            (
                64,
                5120,
                5120,
                1280,
            ),
            (
                64,
                9,
                64,
                64,
                64,
                9,
            ),
            (
                64,
                9,
                9,
                64,
                9,
                64,
            ),
            (
                7,
                18176,
                18176,
                4544,
            ),
            (
                7,
                3072,
                3072,
                768,
            ),
            (
                7,
                4544,
                4544,
                18176,
            ),
            (
                7,
                4544,
                4544,
                4544,
            ),
            (
                7,
                4544,
                4544,
                4672,
            ),
            (
                7,
                4544,
                4544,
                65024,
            ),
            (
                7,
                768,
                768,
                2,
            ),
            (
                7,
                768,
                768,
                2304,
            ),
            (
                7,
                768,
                768,
                3072,
            ),
            (
                7,
                768,
                768,
                768,
            ),
            (
                71,
                7,
                64,
                71,
                64,
                7,
            ),
            (
                71,
                7,
                7,
                71,
                7,
                64,
            ),
            (
                8,
                1,
                1,
                8,
                1,
                64,
            ),
            (
                8,
                1,
                10,
                8,
                10,
                64,
            ),
            (
                8,
                1,
                64,
                8,
                64,
                1,
            ),
            (
                8,
                1,
                64,
                8,
                64,
                10,
            ),
            (
                8,
                10,
                10,
                8,
                10,
                64,
            ),
            (
                8,
                10,
                64,
                8,
                64,
                10,
            ),
            (
                8,
                100,
                100,
                8,
                100,
                32,
            ),
            (
                8,
                100,
                32,
                8,
                32,
                100,
            ),
            (
                8,
                100,
                32,
                8,
                32,
                920,
            ),
            (
                8,
                100,
                920,
                8,
                920,
                32,
            ),
            (
                8,
                1024,
                1024,
                8,
                1024,
                80,
            ),
            (
                8,
                1024,
                80,
                8,
                80,
                1024,
            ),
            (
                8,
                1024,
                80,
                8,
                80,
                9,
            ),
            (
                8,
                1024,
                9,
                8,
                9,
                80,
            ),
            (
                8,
                2048,
                256,
                8,
                256,
                96,
            ),
            (
                8,
                2048,
                32,
                8,
                32,
                256,
            ),
            (
                8,
                256,
                160,
                8,
                160,
                256,
            ),
            (
                8,
                256,
                160,
                8,
                160,
                9,
            ),
            (
                8,
                256,
                2048,
                8,
                2048,
                160,
            ),
            (
                8,
                256,
                256,
                8,
                256,
                160,
            ),
            (
                8,
                256,
                256,
                8,
                256,
                32,
            ),
            (
                8,
                256,
                32,
                8,
                32,
                2048,
            ),
            (
                8,
                256,
                32,
                8,
                32,
                256,
            ),
            (
                8,
                256,
                9,
                8,
                9,
                160,
            ),
            (
                8,
                300,
                300,
                8,
                300,
                64,
            ),
            (
                8,
                300,
                64,
                8,
                64,
                300,
            ),
            (
                8,
                4096,
                40,
                8,
                40,
                4096,
            ),
            (
                8,
                4096,
                40,
                8,
                40,
                9,
            ),
            (
                8,
                4096,
                4096,
                8,
                4096,
                40,
            ),
            (
                8,
                4096,
                9,
                8,
                9,
                40,
            ),
            (
                8,
                64,
                160,
                8,
                160,
                64,
            ),
            (
                8,
                64,
                160,
                8,
                160,
                9,
            ),
            (
                8,
                64,
                64,
                8,
                64,
                160,
            ),
            (
                8,
                64,
                9,
                8,
                9,
                160,
            ),
            (
                8,
                920,
                32,
                8,
                32,
                920,
            ),
            (
                8,
                920,
                920,
                8,
                920,
                32,
            ),
            (
                9,
                1024,
                1024,
                1024,
            ),
            (
                9,
                1024,
                1024,
                128,
            ),
            (
                9,
                1024,
                1024,
                4096,
            ),
            (
                9,
                128,
                128,
                1024,
            ),
            (
                9,
                128,
                128,
                2048,
            ),
            (
                9,
                128,
                128,
                30000,
            ),
            (
                9,
                128,
                128,
                4096,
            ),
            (
                9,
                128,
                128,
                768,
            ),
            (
                9,
                16384,
                16384,
                4096,
            ),
            (
                9,
                2048,
                2048,
                128,
            ),
            (
                9,
                2048,
                2048,
                2048,
            ),
            (
                9,
                2048,
                2048,
                8192,
            ),
            (
                9,
                3072,
                3072,
                768,
            ),
            (
                9,
                4096,
                4096,
                1024,
            ),
            (
                9,
                4096,
                4096,
                128,
            ),
            (
                9,
                4096,
                4096,
                16384,
            ),
            (
                9,
                4096,
                4096,
                4096,
            ),
            (
                9,
                768,
                768,
                128,
            ),
            (
                9,
                768,
                768,
                1280,
            ),
            (
                9,
                768,
                768,
                3072,
            ),
            (
                9,
                768,
                768,
                320,
            ),
            (
                9,
                768,
                768,
                640,
            ),
            (
                9,
                768,
                768,
                768,
            ),
            (
                9,
                8192,
                8192,
                2048,
            ),
            (
                920,
                1,
                256,
                920,
                256,
                256,
            ),
            (
                920,
                2048,
                2048,
                256,
            ),
            (
                920,
                256,
                256,
                2048,
            ),
            (
                920,
                256,
                256,
                256,
            ),
        ],
        "core_grid": [True, False],
        "dtype": [ttnn.float32, ttnn.bfloat16],
        "test_bias": [True, False],
    },
}


# invalidate_vector is called during the generation phase, once for each test vector.
# An invalidated vector is still stored, but it will be skipped.
# Returns (False, None) if the vector is valid, or (True, reason) with a string explaining why it is invalid.
def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    """Reject vectors that combine a bias with a batched (rank > 2) input.

    Vectors with exactly four params describe two rank-2 inputs, which have no
    batch dimension, so they are always accepted. Longer vectors are rejected
    only when a bias is requested and the leading (batch) dimension exceeds 1.
    """
    if not test_vector["test_bias"]:
        return False, None
    shape_params = test_vector["params"]
    is_batched = len(shape_params) > 4 and shape_params[0] > 1
    if is_batched:
        return True, "Batched input not supported when bias exists"
    return False, None


def run_matmul(device, params, core_grid, dtype, test_bias):
    """Run one matmul (or linear, when ``test_bias``) case on device and compare it with torch.

    Args:
        device: ttnn device that receives the input tensors.
        params: Flat sequence with the shape of the first input followed by the
            shape of the second input; the second shape starts at ``len(params) // 2``.
        core_grid: When truthy, ``device.core_grid`` is passed to the op; otherwise ``None``.
        dtype: ttnn dtype for every device tensor. Any dtype other than
            ``ttnn.bfloat16`` uses a HiFi2 / fp32-dest-accumulate compute kernel config.
        test_bias: When True, run ``ttnn.linear`` with a bias of shape ``[1, ..., 1, n]``
            (same rank as the first input); otherwise run ``ttnn.matmul``.

    Returns:
        The result of ``get_run_return`` for the PCC check, the device tensors used,
        the end-to-end time, and the flop-count factors.
    """
    # Apply the same rule as the generation phase. pytest parametrizes these vectors
    # directly, without calling invalidate_vector, so the check has to happen here as well.
    invalid, reason = invalidate_vector(
        {"params": params, "core_grid": core_grid, "dtype": dtype, "test_bias": test_bias}
    )
    if invalid:
        pytest.skip(reason)

    grid = device.core_grid if core_grid else None
    if dtype == ttnn.bfloat16:
        compute_kernel_config = None
    else:
        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
            math_fidelity=ttnn.MathFidelity.HiFi2,
            math_approx_mode=False,
            fp32_dest_acc_en=True,
            packer_l1_acc=True,
        )

    half = len(params) // 2
    shape0 = params[:half]
    shape1 = params[half:]
    torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32)
    torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32)
    torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1)

    input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)

    # Only build and upload the bias when it is used, so device memory is not spent
    # on a tensor the plain matmul path never reads.
    input_tensor2 = None
    if test_bias:
        # Bias broadcasts over every leading dimension: [1, ..., 1, n].
        bias_shape = [1 if i < (half - 1) else shape1[-1] for i in range(half)]
        torch_input_tensor2 = torch.rand(bias_shape, dtype=torch.float32)
        torch_output_tensor += torch_input_tensor2
        input_tensor2 = ttnn.from_torch(torch_input_tensor2, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)

    start_time = start_measuring_time()
    if test_bias:
        op_output_tensor = ttnn.linear(
            input_tensor0,
            input_tensor1,
            core_grid=grid,
            compute_kernel_config=compute_kernel_config,
            bias=input_tensor2,
        )
    else:
        op_output_tensor = ttnn.matmul(
            input_tensor0, input_tensor1, core_grid=grid, compute_kernel_config=compute_kernel_config
        )
    # The timed region includes the device-to-host copy.
    output_tensor = ttnn.to_torch(op_output_tensor)
    e2e_perf = stop_measuring_time(start_time)

    expected_pcc = 0.99
    tensors = [input_tensor0, input_tensor1, op_output_tensor]
    if test_bias:
        tensors.append(input_tensor2)
    flop_counts = list(shape0) + [2, shape1[-1]]  # shape0: all batch dimensions, m, k; shape1[-1]: n
    return get_run_return(torch_output_tensor, output_tensor, expected_pcc, tensors, e2e_perf, flop_counts)


@pytest.mark.parametrize(**gen_pytest_parametrize_args(parameters))
def test_trace(device, params, core_grid, dtype, test_bias):
    """pytest entry point: run one sweep vector, assert that it passes, and log its timing."""
    check_result, e2e_perf = run_matmul(device, params, core_grid, dtype, test_bias)
    passed, failure_message = check_result
    assert passed, failure_message
    logger.info(f"e2e_perf: {e2e_perf}")


def run(
    params,
    core_grid,
    dtype,
    test_bias,
    *,
    device,
) -> list:
    """Sweep-framework entry point; forwards a single test vector to ``run_matmul``."""
    return run_matmul(
        device=device,
        params=params,
        core_grid=core_grid,
        dtype=dtype,
        test_bias=test_bias,
    )
